""" feedforward_net_2d.py
    Feed-forward 2D convolutional neural network.

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import torch
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights, efficientnet_v2_m, EfficientNet_V2_M_Weights, efficientnet_v2_s, EfficientNet_V2_S_Weights, efficientnet_v2_l, EfficientNet_V2_L_Weights
from torchvision.models import resnet34, ResNet34_Weights
from torchvision import models
import torchvision.transforms as T

from .blocks import BasicBlock2D as BasicBlock

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception type specified (W0702),
#     Too many local variables (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class FeedForwardNet(nn.Module):
    """Feed-forward (not weight-tied) 2D network emitting a 2-channel map per iteration.

    Each of the ``max_iters`` iterations owns its own stack of blocks, so at most
    ``max_iters`` distinct iterations can be computed. When more iterations are
    requested than layer stacks remain, the extra output slots repeat the last
    computed output.

    Args:
        block: Block class, called as ``block(self.width, planes, stride, group_norm)``;
            it must expose an ``expansion`` class attribute.
        num_blocks: Sequence giving the number of blocks in each stage of one iteration.
        width: Number of hidden channels (converted with ``int``).
        in_channels: Number of input channels.
        recall: If True, the input is concatenated to the hidden state before every
            iteration and mapped back to ``width`` channels by one shared conv.
        max_iters: Number of per-iteration layer stacks to build.
        group_norm: Forwarded to every block.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, max_iters=8, group_norm=False):
        super().__init__()

        width = int(width)
        self.width = width
        self.recall = recall
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1, bias=False)

        if self.recall:
            # A single conv, shared by all iterations, maps [hidden, input] back to `width` channels.
            self.recall_layer = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                          stride=1, padding=1, bias=False)
        else:
            self.recall_layer = nn.Sequential()

        # Independent (unshared) weights for every iteration.
        self.feedforward_layers = nn.ModuleList()
        for _ in range(max_iters):
            stages = []
            for n_blocks in num_blocks:
                stages.append(self._make_layer(block, width, n_blocks, stride=1))
            self.feedforward_layers.append(nn.Sequential(*stages))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3, stride=1, padding=1, bias=False)

        self.iters = max_iters
        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(), head_conv2, nn.ReLU(), head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        ``self.width`` is updated after each block so the next block is built with
        the previous block's output channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, iters_elapsed=0, **kwargs):
        """Run the per-iteration layer stacks ``iters_elapsed .. iters_elapsed + iters_to_do - 1``.

        Args:
            x: Input batch, indexed below as (batch, channels, height, width).
            iters_to_do: Number of iteration outputs requested. Iterations beyond the
                last available layer stack reuse the last computed output.
            interim_thought: Hidden state to resume from; when None it is initialised
                by projecting ``x``.
            iters_elapsed: Index of the first layer stack to run.
            **kwargs: Ignored.

        Returns:
            In training mode ``(out, interim_thought)``, where ``out`` is the head output of
            the last executed iteration. In eval mode a tensor of shape
            ``(batch, iters_to_do, 2, height, width)`` holding one head output per iteration.

        Raises:
            ValueError: if ``iters_to_do < 1`` or ``iters_elapsed`` is outside
                ``[0, max_iters)``, i.e. no layer stack would be executed.
        """
        if iters_to_do < 1 or not 0 <= iters_elapsed < self.iters:
            raise ValueError(
                f"No layers to run: iters_elapsed={iters_elapsed}, iters_to_do={iters_to_do}, "
                f"max_iters={self.iters}"
            )

        if interim_thought is None:
            # Only needed when starting from scratch; a supplied state makes it unused.
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        layers = self.feedforward_layers[iters_elapsed:iters_elapsed + iters_to_do]
        for i, layer in enumerate(layers):
            if self.recall:
                interim_thought = self.recall_layer(torch.cat([interim_thought, x], 1))
            interim_thought = layer(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        # Fewer layer stacks than requested iterations (possible when resuming at
        # iters_elapsed > 0 too): repeat the last output in every remaining slot.
        n_run = len(layers)
        if n_run < iters_to_do:
            all_outputs[:, n_run:] = out.unsqueeze(1)

        if self.training:
            return out, interim_thought
        return all_outputs


def feedforward_net_2d(width, **kwargs):
    """Build a ``FeedForwardNet`` (one stage of two ``BasicBlock``s per iteration) without recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNet(BasicBlock, [2], width,
                          in_channels=in_channels, recall=False, max_iters=max_iters)


def feedforward_net_recall_2d(width, **kwargs):
    """Build a ``FeedForwardNet`` (one stage of two ``BasicBlock``s per iteration) with recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNet(BasicBlock, [2], width,
                          in_channels=in_channels, recall=True, max_iters=max_iters)


def feedforward_net_gn_2d(width, **kwargs):
    """Build a group-norm ``FeedForwardNet`` (one stage of two ``BasicBlock``s) without recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNet(BasicBlock, [2], width, in_channels=in_channels,
                          recall=False, max_iters=max_iters, group_norm=True)


def feedforward_net_recall_gn_2d(width, **kwargs):
    """Build a group-norm ``FeedForwardNet`` (one stage of two ``BasicBlock``s) with recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNet(BasicBlock, [2], width, in_channels=in_channels,
                          recall=True, max_iters=max_iters, group_norm=True)



class FeedForwardNetMaxPool(nn.Module):
    """Feed-forward (not weight-tied) 2D network with a globally max-pooled vector head.

    Every iteration owns its own stack of blocks. The head applies a conv + ReLU,
    max-pools globally to 1x1, then two more convs (with a ReLU between them),
    yielding ``output_size`` values per sample and iteration.

    Args:
        block: Block class, called as ``block(self.width, planes, stride, group_norm)``;
            it must expose an ``expansion`` class attribute.
        num_blocks: Sequence giving the number of blocks in each stage of one iteration.
        width: Number of hidden channels (converted with ``int``).
        output_size: Number of output values per sample.
        in_channels: Number of input channels.
        recall: If True, the input is concatenated to the hidden state before every
            iteration and mapped back to ``width`` channels by one shared conv.
        max_iters: Number of per-iteration layer stacks to build.
        group_norm: Forwarded to every block.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True, max_iters=8, group_norm=False):
        super().__init__()

        width = int(width)
        self.width = width
        self.recall = recall
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1, bias=False)

        if self.recall:
            # A single conv, shared by all iterations, maps [hidden, input] back to `width` channels.
            self.recall_layer = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                          stride=1, padding=1, bias=False)
        else:
            self.recall_layer = nn.Sequential()

        # Independent (unshared) weights for every iteration.
        self.feedforward_layers = nn.ModuleList()
        for _ in range(max_iters):
            stages = []
            for n_blocks in num_blocks:
                stages.append(self._make_layer(block, width, n_blocks, stride=1))
            self.feedforward_layers.append(nn.Sequential(*stages))

        head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=3, stride=1, padding=1, bias=False)

        self.iters = max_iters
        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        # Pooling happens mid-head: conv2/conv3 operate on a 1x1 map.
        self.head = nn.Sequential(head_conv1, nn.ReLU(), head_pool, head_conv2, nn.ReLU(), head_conv3)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        ``self.width`` is updated after each block so the next block is built with
        the previous block's output channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, iters_elapsed=0, **kwargs):
        """Run the per-iteration layer stacks ``iters_elapsed .. iters_elapsed + iters_to_do - 1``.

        Args:
            x: Input batch; ``x.size(0)`` is the batch size.
            iters_to_do: Number of iterations to run.
            interim_thought: Hidden state to resume from; when None it is initialised
                by projecting ``x``.
            iters_elapsed: Index of the first layer stack to run.
            **kwargs: Ignored.

        Returns:
            In training mode ``(out, interim_thought)`` with ``out`` of shape
            ``(batch, output_size)`` from the last iteration. In eval mode a tensor of
            shape ``(batch, iters_to_do, output_size)``.

        Raises:
            ValueError: unless ``iters_to_do >= 1``, ``iters_elapsed >= 0`` and
                ``iters_elapsed + iters_to_do <= max_iters``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if iters_to_do < 1 or iters_elapsed < 0 or iters_elapsed + iters_to_do > self.iters:
            raise ValueError(
                f"Invalid iteration range: iters_elapsed={iters_elapsed}, iters_to_do={iters_to_do}, "
                f"max_iters={self.iters} (need iters_to_do >= 1, iters_elapsed >= 0 and "
                f"iters_elapsed + iters_to_do <= max_iters)"
            )

        if interim_thought is None:
            # Only needed when starting from scratch; a supplied state makes it unused.
            interim_thought = self.projection(x)

        batch_size = x.size(0)
        all_outputs = torch.zeros((batch_size, iters_to_do, self.output_size), device=x.device)

        layers = self.feedforward_layers[iters_elapsed:iters_elapsed + iters_to_do]
        for i, layer in enumerate(layers):
            if self.recall:
                interim_thought = self.recall_layer(torch.cat([interim_thought, x], 1))
            interim_thought = layer(interim_thought)
            out = self.head(interim_thought).view(batch_size, self.output_size)
            all_outputs[:, i] = out

        if self.training:
            return out, interim_thought
        return all_outputs

def feedforward_net_2d_out10(width, **kwargs):
    """Build a 10-output ``FeedForwardNetMaxPool`` without recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    settings = dict(in_channels=kwargs["in_channels"], recall=False, max_iters=kwargs["max_iters"])
    return FeedForwardNetMaxPool(BasicBlock, [2], width, output_size=10, **settings)


def feedforward_net_recall_2d_out10(width, **kwargs):
    """Build a 10-output ``FeedForwardNetMaxPool`` with recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    settings = dict(in_channels=kwargs["in_channels"], recall=True, max_iters=kwargs["max_iters"])
    return FeedForwardNetMaxPool(BasicBlock, [2], width, output_size=10, **settings)


def feedforward_net_2d_out4(width, **kwargs):
    """Build a 4-output ``FeedForwardNetMaxPool`` without recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    settings = dict(in_channels=kwargs["in_channels"], recall=False, max_iters=kwargs["max_iters"])
    return FeedForwardNetMaxPool(BasicBlock, [2], width, output_size=4, **settings)

def feedforward_net_recall_2d_out4(width, **kwargs):
    """Build a 4-output ``FeedForwardNetMaxPool`` with recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    settings = dict(in_channels=kwargs["in_channels"], recall=True, max_iters=kwargs["max_iters"])
    return FeedForwardNetMaxPool(BasicBlock, [2], width, output_size=4, **settings)



class FeedForwardNetMaxPoolEnd(nn.Module):
    """Feed-forward (not weight-tied) 2D network whose head max-pools at the very end.

    Every iteration owns its own stack of blocks. The head applies three convs
    (ReLU after the first two) at full resolution, producing ``output_size`` channels,
    then max-pools globally to 1x1, yielding ``output_size`` values per sample and
    iteration.

    Args:
        block: Block class, called as ``block(self.width, planes, stride, group_norm)``;
            it must expose an ``expansion`` class attribute.
        num_blocks: Sequence giving the number of blocks in each stage of one iteration.
        width: Number of hidden channels (converted with ``int``).
        output_size: Number of output values per sample.
        in_channels: Number of input channels.
        recall: If True, the input is concatenated to the hidden state before every
            iteration and mapped back to ``width`` channels by one shared conv.
        max_iters: Number of per-iteration layer stacks to build.
        group_norm: Forwarded to every block.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True, max_iters=8, group_norm=False):
        super().__init__()

        width = int(width)
        self.width = width
        self.recall = recall
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1, bias=False)

        if self.recall:
            # A single conv, shared by all iterations, maps [hidden, input] back to `width` channels.
            self.recall_layer = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                          stride=1, padding=1, bias=False)
        else:
            self.recall_layer = nn.Sequential()

        # Independent (unshared) weights for every iteration.
        self.feedforward_layers = nn.ModuleList()
        for _ in range(max_iters):
            stages = []
            for n_blocks in num_blocks:
                stages.append(self._make_layer(block, width, n_blocks, stride=1))
            self.feedforward_layers.append(nn.Sequential(*stages))

        head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=3, stride=1, padding=1, bias=False)

        self.iters = max_iters
        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(), head_conv2, nn.ReLU(), head_conv3, head_pool)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        ``self.width`` is updated after each block so the next block is built with
        the previous block's output channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, iters_elapsed=0, **kwargs):
        """Run the per-iteration layer stacks ``iters_elapsed .. iters_elapsed + iters_to_do - 1``.

        Args:
            x: Input batch; ``x.size(0)`` is the batch size.
            iters_to_do: Number of iterations to run.
            interim_thought: Hidden state to resume from; when None it is initialised
                by projecting ``x``.
            iters_elapsed: Index of the first layer stack to run.
            **kwargs: Ignored.

        Returns:
            In training mode ``(out, interim_thought)`` with ``out`` of shape
            ``(batch, output_size)`` from the last iteration. In eval mode a tensor of
            shape ``(batch, iters_to_do, output_size)``.

        Raises:
            ValueError: unless ``iters_to_do >= 1``, ``iters_elapsed >= 0`` and
                ``iters_elapsed + iters_to_do <= max_iters``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if iters_to_do < 1 or iters_elapsed < 0 or iters_elapsed + iters_to_do > self.iters:
            raise ValueError(
                f"Invalid iteration range: iters_elapsed={iters_elapsed}, iters_to_do={iters_to_do}, "
                f"max_iters={self.iters} (need iters_to_do >= 1, iters_elapsed >= 0 and "
                f"iters_elapsed + iters_to_do <= max_iters)"
            )

        if interim_thought is None:
            # Only needed when starting from scratch; a supplied state makes it unused.
            interim_thought = self.projection(x)

        batch_size = x.size(0)
        all_outputs = torch.zeros((batch_size, iters_to_do, self.output_size), device=x.device)

        layers = self.feedforward_layers[iters_elapsed:iters_elapsed + iters_to_do]
        for i, layer in enumerate(layers):
            if self.recall:
                interim_thought = self.recall_layer(torch.cat([interim_thought, x], 1))
            interim_thought = layer(interim_thought)
            out = self.head(interim_thought).view(batch_size, self.output_size)
            all_outputs[:, i] = out

        if self.training:
            return out, interim_thought
        return all_outputs

def feedforward_net_2d_out10_end(width, **kwargs):
    """Build a 10-output ``FeedForwardNetMaxPoolEnd`` without recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    settings = dict(in_channels=kwargs["in_channels"], recall=False, max_iters=kwargs["max_iters"])
    return FeedForwardNetMaxPoolEnd(BasicBlock, [2], width, output_size=10, **settings)


def feedforward_net_recall_2d_out10_end(width, **kwargs):
    """Build a 10-output ``FeedForwardNetMaxPoolEnd`` with recall.

    ``kwargs`` must provide ``in_channels`` and ``max_iters``; any other keys are ignored.
    """
    settings = dict(in_channels=kwargs["in_channels"], recall=True, max_iters=kwargs["max_iters"])
    return FeedForwardNetMaxPoolEnd(BasicBlock, [2], width, output_size=10, **settings)

def make_torchvision_class(model_module, weights_module):
    """Return an ``nn.Module`` class wrapping the torchvision model builder ``model_module``.

    Instances accept the same ``forward(x, iters_to_do, interim_thought=None,
    iters_elapsed=0, **kwargs)`` call as the iterative networks in this file, but run
    the backbone exactly once.

    Args:
        model_module: torchvision model builder (e.g. ``torchvision.models.resnet18``).
        weights_module: Matching weights enum; its ``DEFAULT`` member is loaded when the
            wrapper is built with ``pretrained=True``.
    """
    # Builders whose classification head is a single nn.Linear stored in `model.fc`.
    fc_head_builders = (resnet18, resnet34, models.shufflenet_v2_x1_0)
    # EfficientNetV2 builders: the whole `model.classifier` is replaced by one nn.Linear
    # (kept this way so parameter names match previously saved checkpoints).
    effnet_builders = (efficientnet_v2_s, efficientnet_v2_m, efficientnet_v2_l)
    # VGG builders: `model.classifier` is an nn.Sequential whose input is the flattened
    # conv features, so only its final nn.Linear is swapped out.
    vgg_builders = (models.vgg16, models.vgg16_bn)

    class PytorchModel(nn.Module):
        """torchvision backbone with resized inputs and an ``output_size``-way head."""

        def __init__(self, output_size, pretrained=True):
            super().__init__()
            self.output_size = output_size

            if pretrained:
                if model_module not in fc_head_builders + effnet_builders + vgg_builders:
                    # Fail before downloading weights whose head we could not adapt.
                    name = getattr(model_module, "__name__", repr(model_module))
                    raise NotImplementedError(f"No pretrained head replacement defined for {name}")

                self.model = model_module(weights=weights_module.DEFAULT)
                if model_module in fc_head_builders:
                    self.model.fc = nn.Linear(self.model.fc.in_features, output_size)
                elif model_module in effnet_builders:
                    in_features = self.model.classifier[-1].in_features
                    self.model.classifier = nn.Linear(in_features, output_size)
                else:
                    last = self.model.classifier[-1]
                    self.model.classifier[-1] = nn.Linear(last.in_features, output_size)
            else:
                self.model = model_module(weights=None, num_classes=output_size)

            # Resize inputs (int size => shorter side to 224) with nearest-neighbour sampling.
            self.transforms = torch.nn.Sequential(
                T.Resize(224, antialias=False, interpolation=T.InterpolationMode.NEAREST),
            )

        def forward(self, x, iters_to_do, interim_thought=None, iters_elapsed=0, **kwargs):
            """Run the backbone once.

            ``iters_to_do``, ``iters_elapsed`` and ``kwargs`` are ignored; ``interim_thought``
            is passed back unchanged in training mode.

            Returns:
                In training mode ``(out, interim_thought)``; in eval mode a tensor of shape
                ``(batch, 1, output_size)`` holding ``out`` as the single "iteration".
            """
            x = self.transforms(x)
            out = self.model(x)

            all_outputs = torch.zeros((x.size(0), 1, self.output_size), device=x.device)
            all_outputs[:, 0] = out

            if self.training:
                return out, interim_thought
            return all_outputs

    return PytorchModel

# Wrapper classes, one per torchvision backbone used by the factory functions below.
Resnet18 = make_torchvision_class(model_module=resnet18, weights_module=ResNet18_Weights)
Resnet34 = make_torchvision_class(model_module=resnet34, weights_module=ResNet34_Weights)
EfficientNetL = make_torchvision_class(model_module=efficientnet_v2_l, weights_module=EfficientNet_V2_L_Weights)
EfficientNetM = make_torchvision_class(model_module=efficientnet_v2_m, weights_module=EfficientNet_V2_M_Weights)
EfficientNetS = make_torchvision_class(model_module=efficientnet_v2_s, weights_module=EfficientNet_V2_S_Weights)
shufflenet_v2_x1_0 = make_torchvision_class(model_module=models.shufflenet_v2_x1_0,
                                            weights_module=models.ShuffleNet_V2_X1_0_Weights)
vgg16 = make_torchvision_class(model_module=models.vgg16, weights_module=models.VGG16_Weights)


def resnet18_out4(width, **kwargs):
    """ResNet-18 wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = Resnet18(4, pretrained=False)
    return model

def resnet34_out4(width, **kwargs):
    """ResNet-34 wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = Resnet34(4, pretrained=False)
    return model

def resnet34_out3(width, **kwargs):
    """ResNet-34 wrapper with 3 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = Resnet34(3, pretrained=False)
    return model

def resnet18_out4_pretrained(width, **kwargs):
    """ResNet-18 wrapper with 4 outputs, loaded from pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = Resnet18(4, pretrained=True)
    return model

def eff_l_out4(width, **kwargs):
    """EfficientNetV2-L wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = EfficientNetL(4, pretrained=False)
    return model

def eff_l_out4_pretrained(width, **kwargs):
    """EfficientNetV2-L wrapper with 4 outputs, loaded from pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = EfficientNetL(4, pretrained=True)
    return model

def eff_m_out4(width, **kwargs):
    """EfficientNetV2-M wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = EfficientNetM(4, pretrained=False)
    return model

def eff_m_out4_pretrained(width, **kwargs):
    """EfficientNetV2-M wrapper with 4 outputs, loaded from pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = EfficientNetM(4, pretrained=True)
    return model

def eff_s_out4(width, **kwargs):
    """EfficientNetV2-S wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = EfficientNetS(4, pretrained=False)
    return model

def eff_s_out4_pretrained(width, **kwargs):
    """EfficientNetV2-S wrapper with 4 outputs, loaded from pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = EfficientNetS(4, pretrained=True)
    return model

def shufflenet_x1_out4(width, **kwargs):
    """ShuffleNetV2 x1.0 wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = shufflenet_v2_x1_0(4, pretrained=False)
    return model

def shufflenet_x1_out4_pretrained(width, **kwargs):
    """ShuffleNetV2 x1.0 wrapper with 4 outputs, loaded from pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = shufflenet_v2_x1_0(4, pretrained=True)
    return model

def vgg16_out4(width, **kwargs):
    """VGG-16 wrapper with 4 outputs and no pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = vgg16(4, pretrained=False)
    return model

def vgg16_out4_pretrained(width, **kwargs):
    """VGG-16 wrapper with 4 outputs, loaded from pretrained weights; ``width`` and ``kwargs`` are ignored."""
    model = vgg16(4, pretrained=True)
    return model

# Testing
if __name__ == "__main__":
    # Smoke test: build a small recall network and print output shapes in train and eval
    # modes. `iters_to_do` is a required argument of forward(); the original calls omitted it.
    net = feedforward_net_recall_2d(width=5, in_channels=3, max_iters=5)
    print(net)
    x_test = torch.rand(4, 3, 5, 5)

    out_test, _ = net(x_test, iters_to_do=5)
    print(out_test.shape)
    out_test, _ = net(x_test, iters_to_do=2)
    print(out_test.shape)

    net.eval()
    outputs = net(x_test, iters_to_do=5)
    print(outputs.shape)
    outputs = net(x_test, iters_to_do=2)
    print(outputs.shape)
